import pandas as pd
import os
import sys
import random
from torchvision import models, transforms
from models.model_baseline_cremad import BaselineModel
from models.model_midas_cremad import MidasModel
from datasets.Cremad_dataset import CREMADDataset
from torch.utils.data import DataLoader
from pytorch_lightning import Trainer
from pytorch_lightning import seed_everything
import pytorch_lightning as pl
import torch
import json
from PIL import Image
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import numpy as np
import argparse

# Absolute directory of this script file.
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
# NOTE(review): this runs *after* the `models.*` / `datasets.*` imports at the
# top of the file, so it cannot affect how those resolve. They resolve because
# running a script directly already puts its directory on sys.path. Move this
# above those imports if the script must also work from another entry point.
sys.path.append(ROOT_DIR)

# Process-wide runtime knobs.
# Presumably silences the HuggingFace `tokenizers` fork/parallelism warning
# raised by a dependency of the model code — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
torch.set_num_threads(1)  # single intra-op CPU thread
torch.set_float32_matmul_precision('high')  # allow faster (e.g. TF32) fp32 matmul kernels

parser = argparse.ArgumentParser(description='Classification')
parser.add_argument(
    '--seed',
    type=int,
    default=2,
    help='random seed',
)
parser.add_argument(
    '--devices',
    type=str,
    default='5',
    help='GPU device ids (comma-separated, e.g., "0,1")',
)
parser.add_argument(
    '--method',
    type=str,
    default='Baseline',
)
args = parser.parse_args()

seed = args.seed
seed_everything(seed, workers=True)

# Lightning's seed_everything already seeds python/numpy/torch. These explicit
# calls re-apply the same seed to every RNG as a belt-and-braces measure.
for _seed_fn in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_fn(seed)

# Prefer reproducible cuDNN kernels over autotuned (possibly non-deterministic) ones.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

if __name__ == "__main__":
    # CSVs shipped with the dataset. They are pooled and re-split below, so the
    # original train/val/test membership is NOT preserved.
    dataset_path = './cremad'
    train_path = f"{dataset_path}/train_split.csv"
    dev_path = f"{dataset_path}/val.csv"
    state_path = f"{dataset_path}/stat.csv"
    test_path = f"{dataset_path}/test.csv"

    combined_df = pd.concat(
        [pd.read_csv(csv_path, sep=',') for csv_path in (train_path, dev_path, test_path)],
        ignore_index=True,
    )

    # Stratified 70/30 split. The 30% holdout is then split 1/3 dev and 2/3
    # test, i.e. roughly 70% train / 10% dev / 20% test overall.
    train_df, holdout_df = train_test_split(
        combined_df, test_size=0.3, random_state=seed, stratify=combined_df['label'],
    )
    dev_df, test_df = train_test_split(
        holdout_df, test_size=2 / 3, random_state=seed, stratify=holdout_df['label'],
    )

    for split_label, split_df in (("Train:", train_df), ("Dev:", dev_df), ("Test:", test_df)):
        print(split_label, split_df.shape)

    # Persist the seed-specific split. The train/dev paths go to the model via
    # hparams; the test path is used for the final evaluation.
    train_split_path, dev_split_path, test_split_path = (
        f"./{prefix}_split_cremad_{seed}.csv" for prefix in ("train", "dev", "test")
    )
    for split_df, split_path in (
        (train_df, train_split_path),
        (dev_df, dev_split_path),
        (test_df, test_split_path),
    ):
        split_df.to_csv(split_path, sep=',', index=False)

    hparams = dict(
        # Required hparams
        train_path=train_split_path,
        dev_path=dev_split_path,
        stat_path=state_path,
        audio_dir=f"{dataset_path}/AudioWAV",
        video_dir=f"{dataset_path}/visual_frames",
        dataset="cremad",
        embedding_dim=150,
        audio_feature_dim=512,
        video_feature_dim=512,
        fusion_output_size=256,
        output_path=f"model-outputs-cremad-{args.method}",
        dev_limit=None,
        weight_decay=1e-4,
        lr=1e-3,
        max_epochs=70,
        # NOTE(review): fixed at 1 even when --devices lists several ids;
        # confirm which of n_gpu / devices the model classes honour.
        n_gpu=1,
        num_workers=5,
        warmup_epochs=5,
        devices=list(map(int, args.devices.split(','))),
        batch_size=64,
        # Gradient accumulation factor; values > 1 "simulate" larger batches.
        accumulate_grad_batches=1,
        # NOTE(review): larger than max_epochs (70). If patience is counted in
        # epochs, early stopping can never trigger — confirm intent.
        early_stop_patience=100,
        use_augmentation=False,
        method=args.method,
        num_classes=6,
    )
    print("method:", args.method)

    def get_model(args, hparams):
        """Build the model selected by ``args.method``, passing it ``hparams``.

        For ``'Midas'`` the warm-up length and augmentation flag are overridden
        by writing into ``hparams`` in place before the model is constructed.

        Raises:
            ValueError: if ``args.method`` is neither ``'Baseline'`` nor ``'Midas'``.
        """
        method = args.method
        if method == 'Midas':
            # Midas-specific training schedule / augmentation.
            hparams.update(warmup_epochs=10, use_augmentation=True)
            return MidasModel(params=hparams)
        if method == 'Baseline':
            return BaselineModel(params=hparams)
        raise ValueError(f"Unsupported model: {method}")

    model = get_model(args, hparams)
    model.fit()

    # Evaluate on the held-out test split written earlier. Batch size and
    # worker count both come from hparams so evaluation tracks the config.
    # (Previously num_workers was hard-coded to 5 and ignored the config.)
    test_dataset = CREMADDataset(
        mode='test',
        audio_path=hparams["audio_dir"],
        visual_path=hparams["video_dir"],
        stat_csv=hparams["stat_path"],
        test_csv=test_split_path,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=hparams["batch_size"],
        shuffle=False,
        num_workers=hparams["num_workers"],
    )
    # Assumes model.fit() leaves the Lightning Trainer it used on model.trainer
    # (confirm in the model classes). ckpt_path="best" reloads the best
    # checkpoint tracked by that trainer's checkpoint callback before testing.
    model.trainer.test(ckpt_path="best", dataloaders=test_loader)
